{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow2 Keras 推理\n",
    "\n",
    "\n",
    "```{topic} 主题\n",
    "TVM 对 TensorFlow1 模型的支持并不完善，为了提高开发效率，本节讨论如何将 TensorFlow1 模型转换为 TensorFlow2 的 Keras 模型，最终再将其转换为 ONNX 或者 TFLite 模型。\n",
    "```\n",
    "\n",
    "参考：[migrating_checkpoints](https://www.tensorflow.org/guide/migrate/migrating_checkpoints)\n",
    "\n",
    "下面以模型 [resnet_v2_50](http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz) 为例展示。\n",
    "\n",
    "需要克隆项目 [models](https://github.com/tensorflow/models)，然后执行如下操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-21 16:49:34.172559: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2023-06-21 16:49:34.247466: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2023-06-21 16:49:34.248487: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-06-21 16:49:35.798317: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "m_gpu = -1 # 禁用 GPU\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = str(m_gpu)\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = str(m_gpu)\n",
    "import tensorflow as tf\n",
    "try:\n",
    "    tf1 = tf.compat.v1\n",
    "except (ImportError, AttributeError):\n",
    "    tf1 = tf\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "切换到 `models/research/slim` 目录下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tasks/models/research/slim\n"
     ]
    }
   ],
   "source": [
    "%cd /media/pc/data/lxw/ai/tasks/models/research/slim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将 TF1 升级为 TF2："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nets import resnet_v2\n",
    "import tf_slim as slim\n",
    "\n",
    "class ResnetV2_50(tf.keras.Model):\n",
    "    def __init__(self, trainable=False, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.trainable = trainable\n",
    "\n",
    "    @tf.function(input_signature=[tf.TensorSpec([1, 3, 299, 299], \n",
    "                                                 tf.float32, name=\"data\")])\n",
    "    @tf1.keras.utils.track_tf1_style_variables\n",
    "    def call(self, x):\n",
    "        # x = tf.convert_to_tensor(x, tf.float32) # 确保输入是 tensor\n",
    "        x = tf.transpose(x, perm=(0, 2, 3, 1)) # NCHW -> NHWC\n",
    "        with slim.arg_scope(resnet_v2.resnet_arg_scope()):\n",
    "            logits, end_points = resnet_v2.resnet_v2_50(\n",
    "                x, \n",
    "                num_classes=1001,\n",
    "                global_pool=True,\n",
    "                is_training=self.trainable,\n",
    "                scope=\"resnet_v2_50\"\n",
    "            )\n",
    "        del end_points\n",
    "        return tf.nn.softmax(logits)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "预处理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-21 16:49:39.660093: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
      "2023-06-21 16:49:39.660172: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: Alg\n",
      "2023-06-21 16:49:39.660183: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: Alg\n",
      "2023-06-21 16:49:39.660370: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 530.30.2\n",
      "2023-06-21 16:49:39.660427: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 530.30.2\n",
      "2023-06-21 16:49:39.660443: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:309] kernel version seems to match DSO: 530.30.2\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "from nets import resnet_v2\n",
    "from tvm_book.data.classification import ImageFolderDataset\n",
    "import tf_slim as slim\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def preprocessing(\n",
    "    image,\n",
    "    use_grayscale=False,\n",
    "    central_fraction=0.875,\n",
    "    central_crop=True,\n",
    "    height=299,\n",
    "    width=299,\n",
    "    mean: tuple[float, ...] = (0.485, 0.456, 0.406),\n",
    "    std: tuple[float, ...] = (1, 1, 1)\n",
    "):\n",
    "    # image = tf.constant(image)\n",
    "    if image.dtype != tf.float32:\n",
    "        image = tf.image.convert_image_dtype(image, dtype=tf.float32)\n",
    "    if use_grayscale:\n",
    "        image = tf.image.rgb_to_grayscale(image)\n",
    "    if central_crop and central_fraction:\n",
    "        image = tf.image.central_crop(image, central_fraction=central_fraction)\n",
    "    if height and width:\n",
    "        image = tf.expand_dims(image, 0)\n",
    "        image = tf.image.resize(image, [height, width],\n",
    "                                method='bilinear',\n",
    "                                preserve_aspect_ratio=False,\n",
    "                                antialias=False)\n",
    "        image = tf.squeeze(image, [0])\n",
    "    image = tf.subtract(image, mean)\n",
    "    image = tf.divide(image, std)\n",
    "    return image\n",
    "\n",
    "\n",
    "# 预处理\n",
    "root = \"/media/pc/data/lxw/home/data/datasets/ILSVRC/val\"\n",
    "valset = ImageFolderDataset(root)\n",
    "image, label_id = valset[1001]\n",
    "model_dir = 'temp/resnet_v2_50'\n",
    "# remove_dir(model_dir)\n",
    "processed_image = preprocessing(\n",
    "    image,\n",
    "    use_grayscale=False,\n",
    "    central_fraction=0.875,\n",
    "    central_crop=True,\n",
    "    height=299,\n",
    "    width=299,\n",
    "    mean=(0.485, 0.456, 0.406),\n",
    "    std=(1, 1, 1)\n",
    ")\n",
    "np_processed_images = np.expand_dims(processed_image.numpy(), axis=0)\n",
    "np_processed_images = np_processed_images.transpose(0, 3, 1, 2)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "前向推理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
      "  warnings.warn('`layer.apply` is deprecated and '\n",
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n",
      "  warnings.warn('`layer.updates` will be removed in a future version. '\n",
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n",
      "  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS\n"
     ]
    }
   ],
   "source": [
    "model = ResnetV2_50()\n",
    "model(tf.ones(shape=(1, 3, 299, 299), dtype=tf.float32))\n",
    "ckpt = tf.train.Checkpoint(model=model)\n",
    "ckpt_path = \"/media/pc/data/board/arria10/lxw/tests/npu_user_demos/models/resnet50_v2_tf/weight/resnet_v2_50.ckpt\"\n",
    "ckpt.restore(ckpt_path) # 更新模型参数\n",
    "outputs = model(np_processed_images)\n",
    "outputs = outputs.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"resnet_v2_50\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      "=================================================================\n",
      "Total params: 25,615,849\n",
      "Trainable params: 0\n",
      "Non-trainable params: 25,615,849\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印标签信息："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "真实标签：water ouzel, dipper\n",
      "20: water ouzel, dipper                   \t0.9207783937454224\n",
      "143: oystercatcher, oyster catcher         \t0.014078204520046711\n",
      "141: redshank, Tringa totanus              \t0.0032907347194850445\n",
      "146: albatross, mollymawk                  \t0.0032017454504966736\n",
      "139: ruddy turnstone, Arenaria interpres   \t0.002742304001003504\n"
     ]
    }
   ],
   "source": [
    "from tvm_book.data.imagenet.classification import ImageNet1kAttr\n",
    "\n",
    "imagenet1k_attr = ImageNet1kAttr()\n",
    "sorted_inds = outputs[0].argsort()[::-1]\n",
    "topk = 5\n",
    "print(f\"真实标签：{imagenet1k_attr.classes_long[label_id]}\")\n",
    "for sorted_ind in sorted_inds[:topk]:\n",
    "    label = imagenet1k_attr.classes_long[sorted_ind-1]\n",
    "    print(f\"{sorted_ind-1}: {label.ljust(38)}\\t{outputs[0, sorted_ind]}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将其模型和参数与加载下来："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # model = ResnetV2_50()\n",
    "# inputs = tf.keras.Input(shape=(224, 224, 3), dtype=tf.float32, name=\"data\")\n",
    "# outputs = model(inputs)\n",
    "# model2 = tf.keras.Model(inputs=inputs, outputs=outputs, name=\"resnet_v2_50_model\")\n",
    "\n",
    "# model2.save(module_with_signature_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
      "  warnings.warn('`layer.apply` is deprecated and '\n",
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n",
      "  warnings.warn('`layer.updates` will be removed in a future version. '\n",
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n",
      "  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS\n"
     ]
    }
   ],
   "source": [
    "module_with_signature_path = \"/tmp/resnet_v2_50_keras\"\n",
    "model.save(module_with_signature_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "imported_with_signatures = tf.saved_model.load(module_with_signature_path)\n",
    "infer = imported_with_signatures.signatures['serving_default']\n",
    "labeling = infer(tf.constant(np_processed_images))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "真实标签：water ouzel, dipper\n",
      "20: water ouzel, dipper                   \t0.9207783937454224\n",
      "143: oystercatcher, oyster catcher         \t0.014078204520046711\n",
      "141: redshank, Tringa totanus              \t0.0032907347194850445\n",
      "146: albatross, mollymawk                  \t0.0032017454504966736\n",
      "139: ruddy turnstone, Arenaria interpres   \t0.002742304001003504\n"
     ]
    }
   ],
   "source": [
    "from tvm_book.data.imagenet.classification import ImageNet1kAttr\n",
    "\n",
    "outputs = labeling['output_1'].numpy()\n",
    "imagenet1k_attr = ImageNet1kAttr()\n",
    "sorted_inds = outputs[0].argsort()[::-1]\n",
    "topk = 5\n",
    "print(f\"真实标签：{imagenet1k_attr.classes_long[label_id]}\")\n",
    "for sorted_ind in sorted_inds[:topk]:\n",
    "    label = imagenet1k_attr.classes_long[sorted_ind-1]\n",
    "    print(f\"{sorted_ind-1}: {label.ljust(38)}\\t{outputs[0, sorted_ind]}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 转换为 ONNX 模型"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Keras 模型转换 ONNX](https://onnxruntime.ai/docs/tutorials/tf-get-started.html)："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-21 16:50:08.883734: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0\n",
      "2023-06-21 16:50:08.883960: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session\n",
      "2023-06-21 16:50:13.670424: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0\n",
      "2023-06-21 16:50:13.671135: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session\n"
     ]
    }
   ],
   "source": [
    "import tf2onnx\n",
    "import onnx\n",
    "\n",
    "input_signature = [tf.TensorSpec([None, 3, 299, 299], tf.float32, name=\"data\")]\n",
    "onnx_model, external_tensor_storage = tf2onnx.convert.from_keras(model, input_signature)\n",
    "onnx.save(onnx_model, \"/tmp/resnet_v2_50_tf.onnx\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建库："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env\n",
    "from tvm.relay.frontend import from_onnx\n",
    "\n",
    "shape_dict = {\"data\": [1, 3, 299, 299]}\n",
    "\n",
    "graph_def = onnx.load(\"/tmp/resnet_v2_50_tf.onnx\")\n",
    "mod, params = from_onnx(\n",
    "    graph_def,\n",
    "    shape_dict,\n",
    "    freeze_params=True\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "推理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n"
     ]
    }
   ],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    lib = relay.build(mod, \"llvm\", params=params)\n",
    "    \n",
    "inputs_dict = {\"data\": np_processed_images}\n",
    "mlib_proxy = tvm.contrib.graph_executor.GraphModule(lib[\"default\"](tvm.cpu()))\n",
    "mlib_proxy.run(**inputs_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "验证一致性："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(\n",
    "    labeling['output_1'].numpy(), \n",
    "    mlib_proxy.get_output(0).numpy(),\n",
    "    rtol=1e-07, atol=1e-5\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 转换为 TFLite 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-21 16:52:14.629868: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.\n",
      "2023-06-21 16:52:14.630049: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.\n",
      "2023-06-21 16:52:14.649710: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/resnet_v2_50_keras\n",
      "2023-06-21 16:52:14.679437: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }\n",
      "2023-06-21 16:52:14.679522: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /tmp/resnet_v2_50_keras\n",
      "2023-06-21 16:52:14.765580: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:353] MLIR V1 optimization pass is not enabled\n",
      "2023-06-21 16:52:14.785594: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.\n",
      "2023-06-21 16:52:15.679311: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: /tmp/resnet_v2_50_keras\n",
      "2023-06-21 16:52:15.933247: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status: success: OK. Took 1283554 microseconds.\n",
      "2023-06-21 16:52:16.844598: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
      "2023-06-21 16:52:18.902510: I tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2116] Estimated count of arithmetic ops: 13.119 G  ops, equivalently 6.559 G  MACs\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Convert the model\n",
    "converter = tf.lite.TFLiteConverter.from_saved_model(module_with_signature_path)\n",
    "tflite_model = converter.convert()\n",
    "\n",
    "# Save the model.\n",
    "with open('temp/resnet_v2_50.tflite', 'wb') as f:\n",
    "    f.write(tflite_model)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载 TFLite 模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tflite\n",
    "\n",
    "\n",
    "with open('temp/resnet_v2_50.tflite', \"rb\") as fp:\n",
    "    tflite_model_buf = fp.read()\n",
    "\n",
    "tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)\n",
    "mod, params = relay.frontend.from_tflite(\n",
    "    tflite_model, shape_dict=shape_dict, \n",
    "    dtype_dict={\"data\": \"float32\"}\n",
    ")\n",
    "desired_layouts = {\n",
    "    # 'image.resize2d': ['NCHW'],\n",
    "    'nn.conv2d': ['NCHW', 'default'],\n",
    "    'nn.max_pool2d': ['NCHW', 'default'],\n",
    "    'nn.avg_pool2d': ['NCHW', 'default'],\n",
    "}\n",
    "# NHWC 将布局转换为 NCHW 且移除未使用算子\n",
    "seq = tvm.transform.Sequential([\n",
    "    relay.transform.RemoveUnusedFunctions(),\n",
    "    relay.transform.ConvertLayout(desired_layouts)\n",
    "])\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    mod = seq(mod)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "验证结果一致性："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "\nNot equal to tolerance rtol=1e-07, atol=1e-05\n\nMismatched elements: 978 / 1001 (97.7%)\nMax absolute difference: 0.92047846\nMax relative difference: 3068.9646\n x: array([[3.429268e-05, 1.693668e-05, 3.029113e-05, ..., 1.208637e-05,\n        9.920573e-06, 2.882769e-05]], dtype=float32)\n y: array([[1.464797e-04, 1.271962e-04, 3.613982e-04, ..., 5.739641e-05,\n        8.909408e-05, 1.503596e-03]], dtype=float32)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[18], line 7\u001b[0m\n\u001b[1;32m      5\u001b[0m mlib_proxy \u001b[39m=\u001b[39m tvm\u001b[39m.\u001b[39mcontrib\u001b[39m.\u001b[39mgraph_executor\u001b[39m.\u001b[39mGraphModule(lib[\u001b[39m\"\u001b[39m\u001b[39mdefault\u001b[39m\u001b[39m\"\u001b[39m](tvm\u001b[39m.\u001b[39mcpu()))\n\u001b[1;32m      6\u001b[0m mlib_proxy\u001b[39m.\u001b[39mrun(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39minputs_dict)\n\u001b[0;32m----> 7\u001b[0m np\u001b[39m.\u001b[39;49mtesting\u001b[39m.\u001b[39;49massert_allclose(\n\u001b[1;32m      8\u001b[0m     labeling[\u001b[39m'\u001b[39;49m\u001b[39moutput_1\u001b[39;49m\u001b[39m'\u001b[39;49m]\u001b[39m.\u001b[39;49mnumpy(), \n\u001b[1;32m      9\u001b[0m     mlib_proxy\u001b[39m.\u001b[39;49mget_output(\u001b[39m0\u001b[39;49m)\u001b[39m.\u001b[39;49mnumpy(),\n\u001b[1;32m     10\u001b[0m     rtol\u001b[39m=\u001b[39;49m\u001b[39m1e-07\u001b[39;49m, atol\u001b[39m=\u001b[39;49m\u001b[39m1e-5\u001b[39;49m\n\u001b[1;32m     11\u001b[0m )\n",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "File \u001b[0;32m/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/numpy/testing/_private/utils.py:844\u001b[0m, in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)\u001b[0m\n\u001b[1;32m    840\u001b[0m         err_msg \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(remarks)\n\u001b[1;32m    841\u001b[0m         msg \u001b[39m=\u001b[39m build_err_msg([ox, oy], err_msg,\n\u001b[1;32m    842\u001b[0m                             verbose\u001b[39m=\u001b[39mverbose, header\u001b[39m=\u001b[39mheader,\n\u001b[1;32m    843\u001b[0m                             names\u001b[39m=\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mx\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39my\u001b[39m\u001b[39m'\u001b[39m), precision\u001b[39m=\u001b[39mprecision)\n\u001b[0;32m--> 844\u001b[0m         \u001b[39mraise\u001b[39;00m \u001b[39mAssertionError\u001b[39;00m(msg)\n\u001b[1;32m    845\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mValueError\u001b[39;00m:\n\u001b[1;32m    846\u001b[0m     \u001b[39mimport\u001b[39;00m \u001b[39mtraceback\u001b[39;00m\n",
      "\u001b[0;31mAssertionError\u001b[0m: \nNot equal to tolerance rtol=1e-07, atol=1e-05\n\nMismatched elements: 978 / 1001 (97.7%)\nMax absolute difference: 0.92047846\nMax relative difference: 3068.9646\n x: array([[3.429268e-05, 1.693668e-05, 3.029113e-05, ..., 1.208637e-05,\n        9.920573e-06, 2.882769e-05]], dtype=float32)\n y: array([[1.464797e-04, 1.271962e-04, 3.613982e-04, ..., 5.739641e-05,\n        8.909408e-05, 1.503596e-03]], dtype=float32)"
     ]
    }
   ],
   "source": [
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    lib = relay.build(mod, \"llvm\", params=params)\n",
    "    \n",
    "inputs_dict = {\"data\": np_processed_images}\n",
    "mlib_proxy = tvm.contrib.graph_executor.GraphModule(lib[\"default\"](tvm.cpu()))\n",
    "mlib_proxy.run(**inputs_dict)\n",
    "np.testing.assert_allclose(\n",
    "    labeling['output_1'].numpy(), \n",
    "    mlib_proxy.get_output(0).numpy(),\n",
    "    rtol=1e-07, atol=1e-5\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{warning}\n",
    "TFLite 转换出现了问题，暂时搁置。\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
